/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.linalg

import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable, Long => JavaLong}
import java.util

import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
import org.apache.spark.SparkException

import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import scala.{specialized => spec}
import scala.language.implicitConversions


/**
  * Represents a numeric vector whose index type is Long and value type is Double.
  *
  * The implementations in this package are [[DenseVector]], [[IntSparseVector]] and
  * [[LongSparseVector]].
  *
  * @note Users should not implement this interface.
  */
sealed trait Vector extends Serializable {

  /** Number of elements in the vector, including entries that are not explicitly stored. */
  def size: Long

  /** Returns the elements of this vector as a dense array of length `size`. */
  def toArray: Array[Double]

  /**
    * Value-based equality across representations.
    *
    * Two vectors are equal when they have the same size and the same non-zero entries at the
    * same positions. Explicitly stored zeros are ignored, so a sparse vector can equal a dense one.
    */
  override def equals(other: Any): Boolean = other match {
    case v2: Vector if this.size != v2.size => false
    case v2: Vector =>
      (this, v2) match {
        case (s1: IntSparseVector, s2: IntSparseVector) =>
          Vectors.equals(s1.indices, s1.values, s2.indices, s2.values)
        case (s1: IntSparseVector, d1: DenseVector) =>
          Vectors.equals(s1.indices, s1.values, 0 until d1.size.toInt, d1.values)
        case (d1: DenseVector, s1: IntSparseVector) =>
          Vectors.equals(0 until d1.size.toInt, d1.values, s1.indices, s1.values)
        // Sparse vectors involving Long indices are compared entry by entry.
        // Densifying them via toArray throws (or truncates) once the size exceeds the Int range.
        case (l1: LongSparseVector, l2: LongSparseVector) =>
          Vectors.equals(l1.indices, l1.values, l2.indices, l2.values)
        case (l1: LongSparseVector, s2: IntSparseVector) =>
          Vectors.equals(l1.indices, l1.values, s2.indices, s2.values)
        case (s1: IntSparseVector, l2: LongSparseVector) =>
          Vectors.equals(s1.indices, s1.values, l2.indices, l2.values)
        case (_, _) => util.Arrays.equals(this.toArray, v2.toArray)
      }
    case _ => false
  }

  /**
    * Reference hash code; every concrete vector class in this file overrides it with a faster loop.
    *
    * Only the first `Vectors.MAX_HASH_NNZ` non-zero entries contribute. Explicit zeros are
    * skipped so that sparse and dense vectors which compare equal also hash equally.
    */
  override def hashCode(): Int = {
    var result: Int = 31 + size.toInt
    var nnz = 0
    // Entries past the hashing budget are visited but ignored.
    // This avoids a nonlocal `return` out of the closure.
    this.foreachActive { (index, value) =>
      if (nnz < Vectors.MAX_HASH_NNZ && value != 0) {
        result = 31 * result + index.toInt
        val bits = java.lang.Double.doubleToLongBits(value)
        result = 31 * result + (bits ^ (bits >>> 32)).toInt
        nnz += 1
      }
    }
    result
  }

  /** Converts this vector to the corresponding Breeze vector. */
  def asBreeze: BV[Double]

  /**
    * Looks up position `i`.
    * The concrete classes differ in how positions without a stored entry are reported.
    */
  def apply(i: Long): Double

  /**
    * Returns a deep copy of this vector.
    *
    * @throws NotImplementedError if the concrete class does not override it.
    */
  def copy: Vector = {
    throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.")
  }

  /** Applies `f` to every explicitly stored (index, value) pair. */
  def foreachActive(f: (Long, Double) => Unit): Unit

  /** Number of explicitly stored entries (every entry, for a dense vector). */
  def numActives: Long

  /** Number of stored entries whose value is not 0.0. */
  def numNonzeros: Long

  /** Converts to a sparse vector that keeps only the non-zero entries. */
  def toSparse: SparseVector = toSparseWithSize(numNonzeros)

  /** Builds a sparse vector of the non-zero entries; `nnz` must equal their count. */
  private[linalg] def toSparseWithSize(nnz: Long): SparseVector

  /** Converts this vector to a dense vector. */
  def toDense: DenseVector

  /** Returns whichever of the dense or sparse representations is estimated to be smaller. */
  def compressed: Vector = {
    val nnz = numNonzeros
    // A dense vector needs 8 * size + 8 bytes, while a sparse vector needs 12 * nnz + 20 bytes.
    if (1.5 * (nnz + 1.0) < size) {
      toSparseWithSize(nnz)
    } else {
      toDense
    }
  }

  /** Index of a maximal element, or -1 when the vector is empty. */
  def argmax: Long
}

/**
  * Factory methods and numeric helpers for [[Vector]].
  * We don't use the name `Vector` because Scala imports
  * `scala.collection.immutable.Vector` by default.
  */
object Vectors {

  /**
    * Creates a dense vector from its values.
    *
    * @param firstValue  first element.
    * @param otherValues remaining elements, in order.
    */
  def dense(firstValue: Double, otherValues: Double*): Vector =
    new DenseVector((firstValue +: otherValues).toArray)

  /**
    * Creates a dense vector from a double array. The array is wrapped, not copied.
    */
  def dense(values: Array[Double]): Vector = new DenseVector(values)

  /** Error message for a sparse-vector index type other than Int or Long. */
  private def unsupportedIndexType(cls: Class[_]): String =
    s"Sparse vector indices must be Int or Long, but got ${cls.getName}."

  /**
    * Creates a sparse vector providing its index array and value array.
    *
    * Int indices produce an [[IntSparseVector]] and Long indices a [[LongSparseVector]].
    * Both arrays are wrapped, not copied.
    *
    * @param size    vector size.
    * @param indices index array, must be strictly increasing.
    * @param values  value array, must have the same length as indices.
    * @throws IllegalArgumentException if `K` is neither Int nor Long, or the data is invalid.
    */
  def sparse[@spec(Int, Long) K: ClassTag](size: Long, indices: Array[K], values: Array[Double]): Vector =
    implicitly[ClassTag[K]].runtimeClass match {
      case intType if classOf[Int] == intType =>
        new IntSparseVector(size, indices.asInstanceOf[Array[Int]], values)
      case longType if classOf[Long] == longType =>
        new LongSparseVector(size, indices.asInstanceOf[Array[Long]], values)
      case other =>
        throw new IllegalArgumentException(unsupportedIndexType(other))
    }

  /**
    * Creates a sparse vector using unordered (index, value) pairs.
    *
    * @param size     vector size.
    * @param elements vector elements in (index, value) pairs; they are sorted by index here.
    * @throws IllegalArgumentException if `K` is neither Int nor Long, or the data is invalid.
    */
  def sparse[@spec(Int, Long) K <% Ordered[K] : ClassTag](size: Long, elements: Seq[(K, Double)]): Vector = {
    val (indices, values) = elements.sortBy(_._1).unzip
    implicitly[ClassTag[K]].runtimeClass match {
      case intType if classOf[Int] == intType =>
        new IntSparseVector(size, indices.toArray.asInstanceOf[Array[Int]], values.toArray)
      case longType if classOf[Long] == longType =>
        new LongSparseVector(size, indices.toArray.asInstanceOf[Array[Long]], values.toArray)
      case other =>
        throw new IllegalArgumentException(unsupportedIndexType(other))
    }
  }

  /**
    * Creates a sparse vector using unordered (index, value) pairs in a Java friendly way.
    *
    * The index type may be given either as the boxed Java class (`java.lang.Integer` /
    * `java.lang.Long`) or as the Scala primitive (`Int` / `Long`).
    *
    * @param size     vector size.
    * @param elements vector elements in (index, value) pairs.
    * @throws IllegalArgumentException if the index type is neither Int nor Long.
    */
  def sparse[K: ClassTag](size: Long, elements: JavaIterable[(K, JavaDouble)]): Vector = {
    implicitly[ClassTag[K]].runtimeClass match {
      case intType if classOf[JavaInteger] == intType || classOf[Int] == intType =>
        sparse(size, elements.asScala.map { case (i: JavaInteger, x) => (i.intValue(), x.doubleValue()) }.toSeq)
      case longType if classOf[JavaLong] == longType || classOf[Long] == longType =>
        sparse(size, elements.asScala.map { case (i: JavaLong, x) => (i.longValue(), x.doubleValue()) }.toSeq)
      case other =>
        throw new IllegalArgumentException(unsupportedIndexType(other))
    }
  }

  /**
    * Creates a vector of all zeros.
    *
    * @param size vector size
    * @return a zero vector
    */
  def zeros(size: Int): Vector = {
    new DenseVector(new Array[Double](size))
  }

  /**
    * Builds a vector from the value produced by `NumericParser.parse`.
    *
    * An `Array[Double]` becomes a dense vector. A `Seq(size, indices, values)` triple becomes a
    * sparse vector. When the triple holds Doubles, it gets Int indices below Int.MaxValue and
    * Long indices otherwise.
    *
    * @throws SparkException if `any` has none of these shapes (including a NaN Double size).
    */
  def parseNumeric(any: Any): Vector = {
    any match {
      case values: Array[Double] =>
        Vectors.dense(values)
      case Seq(size: Int, indices: Array[Int], values: Array[Double]) =>
        Vectors.sparse(size, indices, values)
      case Seq(size: Long, indices: Array[Long], values: Array[Double]) =>
        Vectors.sparse(size, indices, values)
      case Seq(size: Double, indices: Array[Double], values: Array[Double]) if size < Int.MaxValue =>
        Vectors.sparse(size.toInt, indices.map(_.toInt), values)
      // `>=` rather than `>`: a size of exactly Int.MaxValue previously matched neither guard.
      case Seq(size: Double, indices: Array[Double], values: Array[Double]) if size >= Int.MaxValue =>
        Vectors.sparse(size.toLong, indices.map(_.toLong), values)
      case other =>
        throw new SparkException(s"Cannot parse $other.")
    }
  }

  /** Parses a vector from its string form via `NumericParser` (defined elsewhere). */
  def parse(s: String): Vector = {
    parseNumeric(NumericParser.parse(s))
  }

  /**
    * Creates a vector instance from a breeze vector.
    * Backing arrays are reused when they exactly match the vector's contents.
    */
  def fromBreeze(breezeVector: BV[Double]): Vector = {
    breezeVector match {
      case v: BDV[Double] =>
        if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) {
          new DenseVector(v.data)
        } else {
          new DenseVector(v.toArray) // Can't use underlying array directly, so make a new one
        }
      case v: BSV[Double] =>
        if (v.index.length == v.used) {
          new IntSparseVector(v.length, v.index, v.data)
        } else {
          new IntSparseVector(v.length, v.index.slice(0, v.used), v.data.slice(0, v.used))
        }
      case v: BV[_] =>
        sys.error("Unsupported Breeze vector type: " + v.getClass.getName)
    }
  }

  /**
    * Returns the p-norm of this vector.
    *
    * For sparse vectors only the stored values are visited.
    * Unstored entries are zero, so they contribute nothing.
    *
    * @param vector input vector.
    * @param p      norm.
    * @return norm in L^p^ space.
    */
  def norm(vector: Vector, p: Double): Double = {
    require(p >= 1.0, "To compute the p-norm of the vector, we require that you specify a p>=1. " +
      s"You specified p=$p.")
    val values = vector match {
      case DenseVector(vs) => vs
      case IntSparseVector(n, ids, vs) => vs
      case LongSparseVector(n, ids, vs) => vs
      case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
    }
    val size = values.length

    if (p == 1) {
      var sum = 0.0
      var i = 0
      while (i < size) {
        sum += math.abs(values(i))
        i += 1
      }
      sum
    } else if (p == 2) {
      var sum = 0.0
      var i = 0
      while (i < size) {
        sum += values(i) * values(i)
        i += 1
      }
      math.sqrt(sum)
    } else if (p == Double.PositiveInfinity) {
      var max = 0.0
      var i = 0
      while (i < size) {
        val value = math.abs(values(i))
        if (value > max) max = value
        i += 1
      }
      max
    } else {
      var sum = 0.0
      var i = 0
      while (i < size) {
        sum += math.pow(math.abs(values(i)), p)
        i += 1
      }
      math.pow(sum, 1.0 / p)
    }
  }

  /**
    * Returns the squared distance between two Vectors.
    *
    * @param v1 first Vector.
    * @param v2 second Vector.
    * @return squared distance between two Vectors.
    * @throws IllegalArgumentException if the sizes differ.
    */
  def sqdist(v1: Vector, v2: Vector): Double = {
    require(v1.size == v2.size, s"Vector dimensions do not match: Dim(v1)=${v1.size} and Dim(v2)" +
      s"=${v2.size}.")
    var squaredDistance = 0.0
    (v1, v2) match {
      case (v1: IntSparseVector, v2: IntSparseVector) =>
        squaredDistance = sqdist(v1.indices, v1.values, v2.indices, v2.values)
      case (v1: IntSparseVector, v2: DenseVector) =>
        squaredDistance = sqdist(v1.indices, v1.values, v2.values)
      case (v1: DenseVector, v2: IntSparseVector) =>
        squaredDistance = sqdist(v2, v1)
      case (DenseVector(vv1), DenseVector(vv2)) =>
        var kv = 0
        val sz = vv1.length
        while (kv < sz) {
          val score = vv1(kv) - vv2(kv)
          squaredDistance += score * score
          kv += 1
        }
      case (v1: LongSparseVector, DenseVector(vv2)) =>
        squaredDistance = sqdist(v1.indices, v1.values, vv2)
      case (v1: DenseVector, v2: LongSparseVector) =>
        squaredDistance = sqdist(v2, v1)
      case (v1: LongSparseVector, v2: IntSparseVector) =>
        // Widen the Int indices first. The generic merge only casts K2 to K1 without converting,
        // so boxed Ints would otherwise reach the Long comparison and fail at runtime.
        squaredDistance = sqdist(v1.indices, v1.values, v2.indices.map(_.toLong), v2.values)
      case (v1: IntSparseVector, v2: LongSparseVector) =>
        squaredDistance = sqdist(v2, v1)
      case (v1: LongSparseVector, v2: LongSparseVector) =>
        squaredDistance = sqdist(v1.indices, v1.values, v2.indices, v2.values)
      case _ =>
        throw new IllegalArgumentException("Do not support vector type " + v1.getClass +
          " and " + v2.getClass)
    }
    squaredDistance
  }

  /**
    * Returns the squared distance between a sparse vector (`v1Indices`, `v1Values`) and a dense
    * vector (`v2Values`).
    *
    * Assumes `v1Indices` is strictly increasing and every index is below `v2Values.length`.
    */
  def sqdist[K](v1Indices: Array[K], v1Values: Array[Double], v2Values: Array[Double]): Double = {
    var kv1 = 0
    var kv2 = 0

    var squaredDistance = 0.0
    val nnzv1 = v1Indices.length
    val nnzv2 = v2Values.length
    // iv1 is typed Any (the lub of K and Int). The `!=` below relies on Scala's numeric equality
    // between boxed values, so Int and Long indices both compare correctly against kv2.
    var iv1 = if (nnzv1 > 0) v1Indices(kv1) else -1

    while (kv2 < nnzv2) {
      var score = 0.0
      if (kv2 != iv1) {
        score = v2Values(kv2)
      } else {
        score = v1Values(kv1) - v2Values(kv2)
        if (kv1 < nnzv1 - 1) {
          kv1 += 1
          iv1 = v1Indices(kv1)
        }
      }
      squaredDistance += score * score
      kv2 += 1
    }
    squaredDistance
  }

  /**
    * Returns the squared distance between two sparse vectors, merging their sorted index arrays.
    *
    * K2 values are cast to K1 without conversion, so at runtime both index arrays must hold the
    * same primitive type; callers widen Int indices to Long before mixing the two.
    */
  def sqdist[K1 <% Ordered[K1], K2 <% Ordered[K2]](v1Indices: Array[K1], v1Values: Array[Double],
                                                   v2Indices: Array[K2], v2Values: Array[Double]): Double = {
    val nnzv1 = v1Indices.length
    val nnzv2 = v2Indices.length

    implicit def K2toK1(value: K2): K1 = value.asInstanceOf[K1]

    var kv1 = 0
    var kv2 = 0
    var squaredDistance = 0.0
    while (kv1 < nnzv1 || kv2 < nnzv2) {
      var score = 0.0

      if (kv2 >= nnzv2 || (kv1 < nnzv1 && v1Indices(kv1) < v2Indices(kv2))) {
        score = v1Values(kv1)
        kv1 += 1
      } else if (kv1 >= nnzv1 || (kv2 < nnzv2 && v1Indices(kv1) > v2Indices(kv2))) {
        score = v2Values(kv2)
        kv2 += 1
      } else {
        score = v1Values(kv1) - v2Values(kv2)
        kv1 += 1
        kv2 += 1
      }
      squaredDistance += score * score
    }

    squaredDistance
  }

  /**
    * Checks equality between sparse/dense vectors given as parallel (index, value) sequences.
    * Entries whose value is 0.0 are skipped on both sides.
    */
  def equals[K1, K2](
                      v1Indices: IndexedSeq[K1],
                      v1Values: Array[Double],
                      v2Indices: IndexedSeq[K2],
                      v2Values: Array[Double]): Boolean = {
    val v1Size = v1Values.length
    val v2Size = v2Values.length

    var k1 = 0
    var k2 = 0
    var allEqual = true
    while (allEqual) {
      while (k1 < v1Size && v1Values(k1) == 0) k1 += 1
      while (k2 < v2Size && v2Values(k2) == 0) k2 += 1

      if (k1 >= v1Size || k2 >= v2Size) {
        return k1 >= v1Size && k2 >= v2Size // check end alignment
      }
      allEqual = v1Indices(k1) == v2Indices(k2) && v1Values(k1) == v2Values(k2)
      k1 += 1
      k2 += 1
    }
    allEqual
  }

  /** Max number of nonzero entries used in computing hash code. */
  private[linalg] val MAX_HASH_NNZ = 128
}

/**
  * A dense vector backed directly by an array; element `i` is `values(i)`.
  *
  * @param values the elements of the vector. The array is not copied, so later writes to it are
  *               visible through this vector.
  */
class DenseVector(val values: Array[Double]) extends Vector with Serializable {

  override def size: Long = values.length

  override def toString: String = values.mkString("[", ",", "]")

  /** Returns the backing array itself (no copy is made). */
  override def toArray: Array[Double] = values

  override def asBreeze: BV[Double] = new BDV[Double](values)

  override def apply(i: Long): Double = {
    assert(i >= 0 && i < values.length)
    values(i.toInt)
  }

  /** Returns a new vector over a cloned array. */
  override def copy: DenseVector = new DenseVector(values.clone())

  /** Visits every position of the vector, zeros included, in ascending index order. */
  override def foreachActive(f: (Long, Double) => Unit): Unit = {
    val data = values
    val n = data.length
    var idx = 0
    while (idx < n) {
      f(idx, data(idx))
      idx += 1
    }
  }

  override def equals(other: Any): Boolean = super.equals(other)

  /** Hashes the first MAX_HASH_NNZ non-zero entries, matching the sparse implementations. */
  override def hashCode(): Int = {
    var hash: Int = 31 + size.toInt
    var hashed = 0
    var idx = 0
    val n = values.length
    while (idx < n && hashed < Vectors.MAX_HASH_NNZ) {
      val x = values(idx)
      if (x != 0.0) {
        val bits = java.lang.Double.doubleToLongBits(x)
        hash = 31 * (31 * hash + idx) + (bits ^ (bits >>> 32)).toInt
        hashed += 1
      }
      idx += 1
    }
    hash
  }

  /** Every position is stored explicitly. */
  override def numActives: Long = size

  override def numNonzeros: Long = {
    // Equivalent to values.count(_ != 0.0), kept as a plain loop for speed.
    var count = 0
    var idx = 0
    val n = values.length
    while (idx < n) {
      if (values(idx) != 0.0) count += 1
      idx += 1
    }
    count
  }

  /** Collects the non-zero entries into an IntSparseVector; `nnz` must equal their count. */
  private[linalg] override def toSparseWithSize(nnz: Long): IntSparseVector = {
    assert(nnz < Int.MaxValue)

    val outIndices = new Array[Int](nnz.toInt)
    val outValues = new Array[Double](nnz.toInt)
    var pos = 0
    var idx = 0
    val n = values.length
    while (idx < n) {
      val x = values(idx)
      if (x != 0) {
        outIndices(pos) = idx
        outValues(pos) = x
        pos += 1
      }
      idx += 1
    }
    new IntSparseVector(size, outIndices, outValues)
  }

  /** Wraps the same backing array in a new DenseVector. */
  override def toDense: DenseVector = new DenseVector(toArray)

  /** Index of the first maximal element, or -1 for an empty vector. */
  override def argmax: Long = {
    if (size == 0) {
      -1
    } else {
      var best = 0
      var idx = 1
      val n = values.length
      while (idx < n) {
        if (values(idx) > values(best)) best = idx
        idx += 1
      }
      best
    }
  }
}


object DenseVector {

  /** Pattern-matching extractor: `case DenseVector(values) => ...` binds the backing array. */
  def unapply(dv: DenseVector): Option[Array[Double]] = {
    val backing = dv.values
    Some(backing)
  }
}


/**
  * Common supertype of the sparse representations in this file ([[IntSparseVector]] and
  * [[LongSparseVector]]). It declares no members of its own.
  */
trait SparseVector extends Vector {}

/**
  * A sparse vector with Int indices, represented by an index array and a value array.
  *
  * @param size    size of the vector. It is a Long, but every stored index must fit in an Int.
  * @param indices index array, must be non-negative and strictly increasing, and each index
  *                must be less than `size`.
  * @param values  value array, must have the same length as the index array.
  */
class IntSparseVector(
                       override val size: Long,
                       val indices: Array[Int],
                       val values: Array[Double]) extends SparseVector with Serializable {

  // validate the data
  {
    require(size >= 0, "The size of the requested sparse vector must be no less than 0.")
    require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
      s" indices match the dimension of the values. You provided ${indices.length} indices and " +
      s" ${values.length} values.")
    require(indices.length <= size, s"You provided ${indices.length} indices and values, " +
      s"which exceeds the specified vector size ${size}.")

    if (indices.nonEmpty) {
      require(indices(0) >= 0, s"Found negative index: ${indices(0)}.")
    }
    var prev = -1
    indices.foreach { i =>
      require(prev < i, s"Index $i follows $prev and is not strictly increasing")
      prev = i
    }
    require(prev < size, s"Index $prev out of bounds for vector of size $size")
  }

  /**
    * Returns `size` as an Int for building Int-sized representations.
    * Throws instead of letting `size.toInt` silently wrap when the size exceeds the Int range.
    */
  private def checkedIntSize(target: String): Int = {
    if (size > Int.MaxValue) {
      throw new UnsupportedOperationException(
        s"The IntSparseVector is too large to convert to $target")
    }
    size.toInt
  }

  override def toString: String =
    s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"

  /**
    * Expands to a dense array in which unstored positions are 0.0.
    *
    * @throws UnsupportedOperationException if `size` exceeds Int.MaxValue.
    */
  override def toArray: Array[Double] = {
    val data = new Array[Double](checkedIntSize("Array"))
    var i = 0
    val nnz = indices.length
    while (i < nnz) {
      data(indices(i)) = values(i)
      i += 1
    }
    data
  }

  override def copy: IntSparseVector = {
    new IntSparseVector(size, indices.clone(), values.clone())
  }

  /**
    * Wraps the index and value arrays in a Breeze SparseVector.
    *
    * @throws UnsupportedOperationException if `size` exceeds Int.MaxValue.
    */
  override def asBreeze: BV[Double] = new BSV[Double](indices, values, checkedIntSize("Breeze Vector"))

  /** Visits the stored (index, value) pairs in increasing index order. */
  override def foreachActive(f: (Long, Double) => Unit): Unit = {
    var i = 0
    val localValuesSize = values.length
    val localIndices = indices
    val localValues = values

    while (i < localValuesSize) {
      f(localIndices(i), localValues(i))
      i += 1
    }
  }

  override def equals(other: Any): Boolean = super.equals(other)

  /** Hashes the first MAX_HASH_NNZ non-zero entries, consistent with DenseVector.hashCode. */
  override def hashCode(): Int = {
    var result: Int = 31 + size.toInt
    val end = values.length
    var k = 0
    var nnz = 0
    while (k < end && nnz < Vectors.MAX_HASH_NNZ) {
      val v = values(k)
      if (v != 0.0) {
        val i = indices(k)
        result = 31 * result + i
        val bits = java.lang.Double.doubleToLongBits(v)
        result = 31 * result + (bits ^ (bits >>> 32)).toInt
        nnz += 1
      }
      k += 1
    }
    result
  }

  override def numActives: Long = values.length

  override def numNonzeros: Long = {
    var nnz = 0
    values.foreach { v =>
      if (v != 0.0) {
        nnz += 1
      }
    }
    nnz
  }

  /** Drops explicit zeros. Returns `this` when no stored value is zero (`nnz == numActives`). */
  private[linalg] override def toSparseWithSize(nnz: Long): IntSparseVector = {
    assert(nnz < Int.MaxValue)
    if (nnz == numActives) {
      this
    } else {
      val ii = new Array[Int](nnz.toInt)
      val vv = new Array[Double](nnz.toInt)
      var k = 0
      foreachActive { (i, v) =>
        if (v != 0.0) {
          ii(k) = i.toInt
          vv(k) = v
          k += 1
        }
      }
      new IntSparseVector(size, ii, vv)
    }
  }

  /**
    * Index of a maximal element, treating unstored positions as 0.0.
    * Returns -1 for an empty vector and 0 when nothing is stored.
    */
  override def argmax: Long = {
    if (size == 0) {
      -1
    } else if (numActives == 0) {
      0
    } else {
      // Find the max active entry.
      var maxIdx = indices(0)
      var maxValue = values(0)
      var maxJ = 0
      var j = 1
      val na = numActives
      while (j < na) {
        val v = values(j)
        if (v > maxValue) {
          maxValue = v
          maxIdx = indices(j)
          maxJ = j
        }
        j += 1
      }

      // If the max active entry is nonpositive and there exists inactive ones, find the first zero.
      if (maxValue <= 0.0 && na < size) {
        if (maxValue == 0.0) {
          // If there exists an inactive entry before maxIdx, find it and return its index.
          if (maxJ < maxIdx) {
            var k = 0
            while (k < maxJ && indices(k) == k) {
              k += 1
            }
            maxIdx = k
          }
        } else {
          // If the max active value is negative, find and return the first inactive index.
          var k = 0
          while (k < na && indices(k) == k) {
            k += 1
          }
          maxIdx = k
        }
      }

      maxIdx
    }
  }

  /**
    * Create a slice of this vector based on the given indices.
    *
    * @param selectedIndices Unsorted list of indices into the vector.
    *                        This does NOT do bound checking.
    * @return New SparseVector with values in the order specified by the given indices.
    *
    *         NOTE: The API needs to be discussed before making this public.
    *         Also, if we have a version assuming indices are sorted, we should optimize it.
    */
  def slice(selectedIndices: Array[Int]): IntSparseVector = {
    var currentIdx = 0
    val (sliceInds, sliceVals) = selectedIndices.flatMap { origIdx =>
      val iIdx = java.util.Arrays.binarySearch(this.indices, origIdx)
      val i_v = if (iIdx >= 0) {
        Iterator((currentIdx, this.values(iIdx)))
      } else {
        Iterator()
      }
      currentIdx += 1
      i_v
    }.unzip
    new IntSparseVector(selectedIndices.length, sliceInds, sliceVals)
  }

  /**
    * Gets the value stored at position `i`.
    *
    * Returns Double.NaN when nothing is stored there, including positions outside the Int range.
    * NOTE(review): toArray/toDense and equals treat unstored positions as 0.0, so this NaN
    * disagrees with the dense view. Confirm that callers rely on it.
    *
    * @param i index
    */
  override def apply(i: Long): Double = {
    // Stored indices are Ints. Rejecting out-of-range positions up front stops i.toInt from
    // wrapping onto an unrelated stored index.
    if (i < 0 || i > Int.MaxValue) {
      Double.NaN
    } else {
      val iIdx = java.util.Arrays.binarySearch(this.indices, i.toInt)
      if (iIdx >= 0) {
        this.values(iIdx)
      } else {
        Double.NaN
      }
    }
  }

  /**
    * Converts this vector to a dense vector; unstored positions become 0.0.
    */
  override def toDense: DenseVector = {
    assert(size < Int.MaxValue)

    val denseValues = new Array[Double](size.toInt)
    indices.zip(values).foreach { case (idx, value) =>
      denseValues(idx) = value
    }

    new DenseVector(denseValues)
  }
}


object IntSparseVector {

  /** Extractor used by `case IntSparseVector(size, indices, values) => ...` patterns. */
  def unapply(sv: IntSparseVector): Option[(Long, Array[Int], Array[Double])] = {
    val parts = (sv.size, sv.indices, sv.values)
    Some(parts)
  }
}


/**
  * A sparse vector with Long indices, represented by an index array and a value array.
  *
  * @param size    size of the vector.
  * @param indices index array, must be non-negative and strictly increasing, and each index
  *                must be less than `size`.
  * @param values  value array, must have the same length as the index array.
  */
class LongSparseVector(
                        override val size: Long,
                        val indices: Array[Long],
                        val values: Array[Double]) extends SparseVector with Serializable {

  // validate the data
  {
    require(size >= 0, "The size of the requested sparse vector must be no less than 0.")
    require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
      s" indices match the dimension of the values. You provided ${indices.length} indices and " +
      s" ${values.length} values.")
    require(indices.length <= size, s"You provided ${indices.length} indices and values, " +
      s"which exceeds the specified vector size ${size}.")

    if (indices.nonEmpty) {
      require(indices(0) >= 0, s"Found negative index: ${indices(0)}.")
    }
    var prev = -1L
    indices.foreach { i =>
      require(prev < i, s"Index $i follows $prev and is not strictly increasing")
      prev = i
    }
    require(prev < size, s"Index $prev out of bounds for vector of size $size")
  }

  override def toString: String =
    s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"

  /**
    * Expands to a dense array in which unstored positions are 0.0.
    *
    * @throws Exception if `size` is not below Int.MaxValue.
    */
  override def toArray: Array[Double] = {
    if (size < Int.MaxValue) {
      val data = new Array[Double](size.toInt)
      var i = 0
      val nnz = indices.length
      while (i < nnz) {
        data(indices(i).toInt) = values(i)
        i += 1
      }
      data
    } else {
      throw new Exception("The LongSparseVector is too large to convert to Array")
    }

  }

  override def copy: LongSparseVector = {
    new LongSparseVector(size, indices.clone(), values.clone())
  }

  /**
    * Converts to a Breeze SparseVector with the indices narrowed to Int.
    *
    * @throws Exception if `size` is not below Int.MaxValue.
    */
  override def asBreeze: BV[Double] = {
    if (size < Int.MaxValue) {
      new BSV[Double](indices.map(_.toInt), values, size.toInt)
    } else {
      throw new Exception("The LongSparseVector is too large to convert to Breeze Vector")
    }
  }

  /** Visits the stored (index, value) pairs in increasing index order. */
  override def foreachActive(f: (Long, Double) => Unit): Unit = {
    var i = 0
    val localValuesSize = values.length
    val localIndices = indices
    val localValues = values

    while (i < localValuesSize) {
      f(localIndices(i), localValues(i))
      i += 1
    }
  }

  override def equals(other: Any): Boolean = super.equals(other)

  /**
    * Hashes the first MAX_HASH_NNZ non-zero entries.
    * Indices are narrowed to Int, the same way the trait's reference implementation narrows them.
    */
  override def hashCode(): Int = {
    var result: Int = 31 + size.toInt
    val end = values.length
    var k = 0
    var nnz = 0
    while (k < end && nnz < Vectors.MAX_HASH_NNZ) {
      val v = values(k)
      if (v != 0.0) {
        val i = indices(k).toInt
        result = 31 * result + i
        val bits = java.lang.Double.doubleToLongBits(v)
        result = 31 * result + (bits ^ (bits >>> 32)).toInt
        nnz += 1
      }
      k += 1
    }
    result
  }

  override def numActives: Long = values.length

  override def numNonzeros: Long = {
    var nnz = 0
    values.foreach { v =>
      if (v != 0.0) {
        nnz += 1
      }
    }
    nnz
  }

  /** Drops explicit zeros. Returns `this` when no stored value is zero (`nnz == numActives`). */
  private[linalg] override def toSparseWithSize(nnz: Long): LongSparseVector = {
    assert(nnz < Int.MaxValue)

    if (nnz == numActives) {
      this
    } else {
      val ii = new Array[Long](nnz.toInt)
      val vv = new Array[Double](nnz.toInt)
      var k = 0
      foreachActive { (i, v) =>
        if (v != 0.0) {
          ii(k) = i
          vv(k) = v
          k += 1
        }
      }

      new LongSparseVector(size, ii, vv)
    }
  }

  /**
    * Index of a maximal element, treating unstored positions as 0.0.
    * Returns -1 for an empty vector and 0 when nothing is stored.
    */
  override def argmax: Long = {
    if (size == 0) {
      -1
    } else if (numActives == 0) {
      0
    } else {
      // Find the max active entry.
      var maxIdx = indices(0)
      var maxValue = values(0)
      var maxJ = 0
      var j = 1
      val na = numActives
      while (j < na) {
        val v = values(j)
        if (v > maxValue) {
          maxValue = v
          maxIdx = indices(j)
          maxJ = j
        }
        j += 1
      }

      // If the max active entry is nonpositive and there exists inactive ones, find the first zero.
      if (maxValue <= 0.0 && na < size) {
        if (maxValue == 0.0) {
          // If there exists an inactive entry before maxIdx, find it and return its index.
          if (maxJ < maxIdx) {
            var k = 0
            while (k < maxJ && indices(k) == k) {
              k += 1
            }
            maxIdx = k
          }
        } else {
          // If the max active value is negative, find and return the first inactive index.
          var k = 0
          while (k < na && indices(k) == k) {
            k += 1
          }
          maxIdx = k
        }
      }

      maxIdx
    }
  }

  /**
    * Create a slice of this vector based on the given indices.
    *
    * @param selectedIndices Unsorted list of indices into the vector.
    *                        This does NOT do bound checking.
    * @return New SparseVector with values in the order specified by the given indices.
    *
    *         NOTE: The API needs to be discussed before making this public.
    *         Also, if we have a version assuming indices are sorted, we should optimize it.
    */
  def slice(selectedIndices: Array[Long]): LongSparseVector = {
    var currentIdx = 0L
    val (sliceInds, sliceVals) = selectedIndices.flatMap { origIdx =>
      val iIdx = java.util.Arrays.binarySearch(this.indices, origIdx)
      val i_v = if (iIdx >= 0) {
        Iterator((currentIdx, this.values(iIdx)))
      } else {
        Iterator()
      }
      currentIdx += 1
      i_v
    }.unzip
    new LongSparseVector(selectedIndices.length, sliceInds, sliceVals)
  }

  /**
    * Gets the value stored at position `i`, or Double.NaN when nothing is stored there.
    *
    * NOTE(review): toArray/toDense and equals treat unstored positions as 0.0, so this NaN
    * disagrees with the dense view. Confirm that callers rely on it.
    */
  override def apply(i: Long): Double = {
    // Search with the full Long index. Narrowing it to Int (as before) mapped positions beyond
    // Int.MaxValue onto unrelated small indices.
    val iIdx = java.util.Arrays.binarySearch(this.indices, i)
    if (iIdx >= 0) {
      this.values(iIdx)
    } else {
      Double.NaN
    }
  }

  /**
    * Converts to a dense vector; unstored positions become 0.0.
    *
    * @throws Exception if `size` is not below Int.MaxValue.
    */
  override def toDense: DenseVector = {
    if (size < Int.MaxValue) {
      new DenseVector(toArray)
    } else {
      throw new Exception("The LongSparseVector is too large to convert to Array")
    }
  }
}


object LongSparseVector {

  /** Extractor used by `case LongSparseVector(size, indices, values) => ...` patterns. */
  def unapply(sv: LongSparseVector): Option[(Long, Array[Long], Array[Double])] = {
    val parts = (sv.size, sv.indices, sv.values)
    Some(parts)
  }
}